{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "“Bi-LSTM(Attention)-Torch.ipynb”的副本",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "gnx-OWyWsdji",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "'''\n",
        "  code by Tae Hwan Jung(Jeff Jung) @graykode, modify by wmathor\n",
        "  Reference : https://github.com/prakashpandey9/Text-Classification-Pytorch/blob/master/models/LSTM_Attn.py\n",
        "'''\n",
        "import torch\n",
        "import numpy as np\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "import torch.nn.functional as F\n",
        "import matplotlib.pyplot as plt\n",
        "import torch.utils.data as Data\n",
        "\n",
        "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
      ],
      "execution_count": 6,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "oqpm5IqKs2l2",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Bi-LSTM(Attention) Parameters\n",
        "batch_size = 3\n",
        "embedding_dim = 2\n",
        "n_hidden = 5 # number of hidden units in one cell\n",
        "num_classes = 2  # 0 or 1\n",
        "\n",
        "# 3 words sentences (=sequence_length is 3)\n",
        "sentences = [\"i love you\", \"he loves me\", \"she likes baseball\", \"i hate you\", \"sorry for that\", \"this is awful\"]\n",
        "labels = [1, 1, 1, 0, 0, 0]  # 1 is good, 0 is not good.\n",
        "\n",
        "vocab = list(set(\" \".join(sentences).split()))\n",
        "word2idx = {w: i for i, w in enumerate(vocab)}\n",
        "vocab_size = len(word2idx)\n",
        "\n",
        "def make_data(sentences):\n",
        "  inputs = []\n",
        "  for sen in sentences:\n",
        "      inputs.append(np.asarray([word2idx[n] for n in sen.split()]))\n",
        "\n",
        "  targets = []\n",
        "  for out in labels:\n",
        "      targets.append(out) # To using Torch Softmax Loss function\n",
        "\n",
        "  return torch.LongTensor(inputs), torch.LongTensor(targets)\n",
        "\n",
        "inputs, targets = make_data(sentences)\n",
        "dataset = Data.TensorDataset(inputs, targets)\n",
        "loader = Data.DataLoader(dataset, batch_size, True)"
      ],
      "execution_count": 7,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "UKCQaHZTtqT3",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class BiLSTM_Attention(nn.Module):\n",
        "    def __init__(self):\n",
        "        super(BiLSTM_Attention, self).__init__()\n",
        "        self.embedding = nn.Embedding(vocab_size, embedding_dim)\n",
        "        self.lstm = nn.LSTM(embedding_dim, n_hidden, bidirectional=True)\n",
        "        self.out = nn.Linear(n_hidden * 2, num_classes)\n",
        "\n",
        "    # lstm_output : [batch_size, n_step, n_hidden * num_directions(=2)], F matrix\n",
        "    def attention_net(self, lstm_output, final_state):\n",
        "        batch_size = len(lstm_output)\n",
        "        hidden = final_state.view(batch_size, -1, 1)   # hidden : [batch_size, n_hidden * num_directions(=2), n_layer(=1)]\n",
        "        attn_weights = torch.bmm(lstm_output, hidden).squeeze(2) # attn_weights : [batch_size, n_step]\n",
        "        soft_attn_weights = F.softmax(attn_weights, 1)\n",
        "\n",
        "        # context : [batch_size, n_hidden * num_directions(=2)]\n",
        "        context = torch.bmm(lstm_output.transpose(1, 2), soft_attn_weights.unsqueeze(2)).squeeze(2)\n",
        "        return context, soft_attn_weights \n",
        "\n",
        "    def forward(self, X):\n",
        "        '''\n",
        "        X: [batch_size, seq_len]\n",
        "        '''\n",
        "        input = self.embedding(X) # input : [batch_size, seq_len, embedding_dim]\n",
        "        input = input.transpose(0, 1) # input : [seq_len, batch_size, embedding_dim]\n",
        "\n",
        "        # final_hidden_state, final_cell_state : [num_layers(=1) * num_directions(=2), batch_size, n_hidden]\n",
        "        output, (final_hidden_state, final_cell_state) = self.lstm(input)\n",
        "        output = output.transpose(0, 1) # output : [batch_size, seq_len, n_hidden]\n",
        "        attn_output, attention = self.attention_net(output, final_hidden_state)\n",
        "        return self.out(attn_output), attention # model : [batch_size, num_classes], attention : [batch_size, n_step]\n",
        "\n",
        "model = BiLSTM_Attention().to(device)\n",
        "criterion = nn.CrossEntropyLoss().to(device)\n",
        "optimizer = optim.Adam(model.parameters(), lr=0.001)"
      ],
      "execution_count": 13,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "0WS_AKj4hR22",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 185
        },
        "outputId": "ecfda77f-2abc-4688-ae64-a87559e60449"
      },
      "source": [
        "# Training\n",
        "for epoch in range(5000):\n",
        "    for x, y in loader:\n",
        "        x, y = x.to(device), y.to(device)\n",
        "        pred, attention = model(x)\n",
        "        loss = criterion(pred, y)\n",
        "        if (epoch + 1) % 1000 == 0:\n",
        "            print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))\n",
        "      \n",
        "        optimizer.zero_grad()\n",
        "        loss.backward()\n",
        "        optimizer.step()"
      ],
      "execution_count": 14,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Epoch: 1000 cost = 0.007154\n",
            "Epoch: 1000 cost = 0.002582\n",
            "Epoch: 2000 cost = 0.001240\n",
            "Epoch: 2000 cost = 0.000944\n",
            "Epoch: 3000 cost = 0.000314\n",
            "Epoch: 3000 cost = 0.000074\n",
            "Epoch: 4000 cost = 0.000135\n",
            "Epoch: 4000 cost = 0.000015\n",
            "Epoch: 5000 cost = 0.000017\n",
            "Epoch: 5000 cost = 0.000015\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "aqDpw1VRtuuV",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "697fa85d-d794-40f4-d3f3-3a9b8f07fd2e"
      },
      "source": [
        "# Test\n",
        "test_text = 'i hate me'\n",
        "tests = [np.asarray([word2idx[n] for n in test_text.split()])]\n",
        "test_batch = torch.LongTensor(tests).to(device)\n",
        "\n",
        "# Predict\n",
        "predict, _ = model(test_batch)\n",
        "predict = predict.data.max(1, keepdim=True)[1]\n",
        "if predict[0][0] == 0:\n",
        "    print(test_text,\"is Bad Mean...\")\n",
        "else:\n",
        "    print(test_text,\"is Good Mean!!\")\n",
        "    \n",
        "# fig = plt.figure(figsize=(6, 3)) # [batch_size, n_step]\n",
        "# ax = fig.add_subplot(1, 1, 1)\n",
        "# ax.matshow(attention.cpu().data, cmap='viridis')\n",
        "# ax.set_xticklabels(['']+['first_word', 'second_word', 'third_word'], fontdict={'fontsize': 14}, rotation=90)\n",
        "# ax.set_yticklabels(['']+['batch_1', 'batch_2', 'batch_3', 'batch_4', 'batch_5', 'batch_6'], fontdict={'fontsize': 14})\n",
        "# plt.show()"
      ],
      "execution_count": 19,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "i hate me is Good Mean!!\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "IuEy01KDv0cr",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}